@tillahoffmann
Collaborator

@tillahoffmann tillahoffmann commented Dec 4, 2025

This PR adds an initial_value argument to the GaussianStateSpace distribution as suggested in #2098.

As part of this change, I added an optional constraint. I'm a bit torn on whether that's the right choice; we could instead promote 0 to the right shape. However, that would make evaluating the mean of the distribution relatively inefficient when there is no initial value: we'd still scan over the sequence even though we should really just return zeros (although maybe jax.jit amortizes that?). Open to suggestions.
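
To make the trade-off concrete, here is a minimal sketch of the two paths for the mean. It is not the code in this PR: `state_space_mean` is a hypothetical helper, and it assumes the recursion z_t = A z_{t-1} + eps_t with zero-mean noise and z_0 equal to the initial value.

```python
import jax
import jax.numpy as jnp


def state_space_mean(transition_matrix, num_steps, initial_value=None):
    """Mean of z_1..z_num_steps for z_t = A z_{t-1} + eps_t (hypothetical helper)."""
    state_dim = transition_matrix.shape[-1]
    if initial_value is None:
        # Without an initial value the mean is zero at every step, so no scan is needed.
        return jnp.zeros((num_steps, state_dim))

    def step(carry, _):
        # Propagate the mean one step: E[z_t] = A E[z_{t-1}].
        nxt = transition_matrix @ carry
        return nxt, nxt

    _, means = jax.lax.scan(step, initial_value, xs=None, length=num_steps)
    return means


A = jnp.array([[0.5]])
print(state_space_mean(A, 3))                   # zeros, no scan
print(state_space_mean(A, 3, jnp.array([2.0])))  # scans over the sequence
```

Promoting 0 to the right shape would always take the second path and produce the same zeros as the first, which is the inefficiency described above.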

@tillahoffmann tillahoffmann added enhancement New feature or request question Further information is requested labels Dec 4, 2025
@tillahoffmann tillahoffmann force-pushed the init-gaussian-state-space branch from a20053c to 70ba63c Compare December 4, 2025 02:41
@javier-garcia-tilburg

javier-garcia-tilburg commented Dec 6, 2025

I was playing around with this simple example and I like it 👍

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model(data):
    reversion_speed = numpyro.sample("k",dist.HalfNormal(1.0))
    sigma = numpyro.sample("sigma",dist.HalfNormal(1.0))
    x0 = numpyro.sample(
        "x0",
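        # Normal takes a standard deviation, so use the square root of the
        # variance sigma^2 / (2 k), matching how the data below are simulated.
        dist.Normal(0.0, jnp.sqrt(jnp.divide(jnp.power(sigma, 2), 2 * reversion_speed))),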
        obs=data[0]
    )

    numpyro.sample(
        "x",
        dist.GaussianStateSpace(
            num_steps=jnp.shape(data)[-1]-1,
            transition_matrix=jnp.array([[ jnp.exp(- jnp.multiply(reversion_speed, 1.0)) - 1 ]]),
            covariance_matrix=jnp.array([[ jnp.multiply(jnp.divide(jnp.power(sigma, 2), 2), jnp.divide(1 - jnp.exp(- jnp.multiply(reversion_speed, 1.0)), reversion_speed)) ]]),
            initial_value = jnp.stack([x0])
        ),
        obs=jnp.stack([data[1:]], axis=-1)
    )

mcmc = MCMC(
    NUTS(
        model=model
    ), 
    num_warmup=500, 
    num_samples=1_000
)
mcmc.run(
    rng_key=jax.random.PRNGKey(2),
    data=(
        lambda reversion_speed, sigma, std_norm: jnp.concatenate([
            jnp.array([jnp.sqrt(jnp.divide(jnp.power(sigma, 2), 2 * reversion_speed)) * std_norm[0]]),
            jax.lax.scan(
                lambda y, x: (jnp.multiply(y, jnp.exp(- jnp.multiply(reversion_speed, 1.0)) - 1) + x * jnp.sqrt(jnp.multiply(jnp.divide(jnp.power(sigma, 2), 2), jnp.divide(1 - jnp.exp(- jnp.multiply(reversion_speed, 1.0)), reversion_speed))),) * 2,
                init=jnp.sqrt(jnp.divide(jnp.power(sigma, 2), 2 * reversion_speed)) * std_norm[0],
                xs=std_norm[1:]
            )[1]
        ])
    )(
        0.1, 0.5, jax.random.normal(key=jax.random.PRNGKey(10), shape=(20,))
    )
)
mcmc.print_summary()

# The mean of the base distribution is zero and it has the right shape.
return self.base_dist.mean
# If there's no initial value, the mean is zero (base distribution mean).
if self.initial_value is None:
Member

How about checking whether self._initial_value is None, and having the self.initial_value property return zero in that case?
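
A minimal sketch of that pattern, assuming the raw argument is stored on a private attribute and exposed through a property. `StateSpaceSketch` is a hypothetical stand-in, not the actual `GaussianStateSpace` class, and the zero default is shaped from the last dimension of the transition matrix as an assumption.

```python
import jax.numpy as jnp


class StateSpaceSketch:
    """Hypothetical stand-in illustrating the property-based default."""

    def __init__(self, transition_matrix, initial_value=None):
        self.transition_matrix = transition_matrix
        # Keep the raw argument so other methods can still test for None.
        self._initial_value = initial_value

    @property
    def initial_value(self):
        # Callers always get an array; zeros stand in for a missing value.
        if self._initial_value is None:
            return jnp.zeros(self.transition_matrix.shape[-1])
        return self._initial_value


d = StateSpaceSketch(jnp.eye(2))
print(d.initial_value)           # zeros of shape (2,)
print(d._initial_value is None)  # the fast path for the mean can still check this
```

With this layout the mean can still short-circuit on `self._initial_value is None`, while code that needs a concrete array reads `self.initial_value`.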

Labels

awaiting response enhancement New feature or request question Further information is requested
